# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
from manager import BaseManager
from sklearn.metrics import auc, roc_auc_score, roc_curve

"""SLModel

"""
import logging
import math
import os
from typing import Callable, Dict, Iterable, List, Tuple, Union

from multiprocess import cpu_count
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm

import secretflow as sf
from secretflow.data.base import Partition
from secretflow.data.horizontal import HDataFrame
from secretflow.data.ndarray import FedNdarray
from secretflow.data.vertical import VDataFrame
from secretflow.device import PYU, Device, reveal, wait
from secretflow.device.device.pyu import PYUObject
from secretflow.ml.nn.sl.agglayer.agg_layer import AggLayer
from secretflow.ml.nn.sl.agglayer.agg_method import AggMethod
from secretflow.ml.nn.sl.strategy_dispatcher import dispatch_strategy
from secretflow.security.privacy import DPStrategy
from secretflow.utils.random import global_random


def attach_ExPloitattack_to_splitnn_sf(
    cls, attack_criterion=None, target_client_index=0, device="cpu"
):
    """Return a subclass of ``cls`` (an SLModel-like class) augmented with the
    ExPloit label-inference attack helpers.

    Args:
        cls: the split-learning model class to wrap.
        attack_criterion: stored on the wrapper instance; not referenced by the
            visible methods — presumably consumed by callers. TODO confirm.
        target_client_index: stored on the wrapper instance; not referenced by
            the visible methods.
        device: device string (e.g. "cpu"); stored on the wrapper instance.
    """

    class ExPloitAttackSplitNNWrapper_sf(cls):
        # Wrapper that adds surrogate-model training / attack methods on top
        # of the wrapped split-learning model class.
        def __init__(self, *args, **kwargs):
            # Delegate normal construction to the wrapped class, then record
            # the attack configuration captured from the enclosing closure.
            super(ExPloitAttackSplitNNWrapper_sf, self).__init__(*args, **kwargs)
            self.attack_criterion = attack_criterion
            self.target_client_index = target_client_index
            self.device = device

        @staticmethod
        def convert_to_ndarray(*data: List) -> Union[List[jnp.ndarray], jnp.ndarray]:
            """Convert tf/torch/numpy inputs to jax ndarrays, preserving list structure."""

            def _to_jnp(item):
                # Already a jax array: nothing to do.
                if isinstance(item, jnp.ndarray):
                    return item
                if isinstance(item, (tf.Tensor, torch.Tensor)):
                    item = jnp.array(item.numpy())
                if isinstance(item, np.ndarray):
                    item = jnp.array(item)
                return item

            if isinstance(data, Tuple) and len(data) == 1:
                # Packing/unpacking through a PYU yields a 1-tuple when
                # 'num_return' is not specified; unwrap it.
                data = data[0]
            if isinstance(data, (List, Tuple)):
                return [_to_jnp(item) for item in data]
            return _to_jnp(data)

        @staticmethod
        def convert_to_tensor(hidden: Union[List, Tuple], backend: str):
            """Convert array(s) back into framework tensors for ``backend``.

            Supported backends: "tensorflow" and "torch"; anything else is
            returned unchanged.
            """
            is_sequence = isinstance(hidden, (List, Tuple))
            if backend == "tensorflow":
                if is_sequence:
                    hidden = [tf.convert_to_tensor(item) for item in hidden]
                else:
                    hidden = tf.convert_to_tensor(hidden)
            elif backend == "torch":
                if is_sequence:
                    hidden = [torch.Tensor(item) for item in hidden]
                else:
                    hidden = torch.Tensor(hidden)
            return hidden

        # def extract_intermidiate_gradient(self, outputs):
        #     self.backward_gradient(outputs.grad)
        #     return self.clients[self.target_client_index].grad_from_next_client
        def get_dataloader(self, X, batch_size, mode="train"):
            """Build a batched ``tf.data.Dataset`` for surrogate training or testing.

            Args:
                X: feature tensor (aggregated hidden representations).
                batch_size: number of samples per batch.
                mode: "train" yields (X, y, p_hat, grad) tuples using the
                    gradients/predictions saved by ``fit_save`` under ./model;
                    "test" yields (X, y) pairs using ./model/ys_test.csv.

            Returns:
                A batched ``tf.data.Dataset``.

            Raises:
                ValueError: if ``mode`` is neither "train" nor "test".
            """
            if mode == "train":
                # Gradients and soft predictions captured during fit_save().
                data = np.load("./model/grads.npz")
                p_hats = data["p_hat"]
                grads = data["grads"]
                # Labels: the second column of the saved CSV holds the values.
                y_train = pd.read_csv("./model/ys.csv")
                y_train = y_train.iloc[:, 1].to_numpy()

                X_train = tf.cast(X, dtype=tf.float32)
                y_train = tf.cast(y_train, dtype=tf.float32)
                p_hats = tf.cast(p_hats, dtype=tf.float32)
                grads = tf.cast(grads, dtype=tf.float32)
                dataset = tf.data.Dataset.from_tensor_slices(
                    (X_train, y_train, p_hats, grads)
                )
            elif mode == "test":
                y_test = pd.read_csv("./model/ys_test.csv")
                y_test = y_test.iloc[:, 1].to_numpy()

                X_test = tf.cast(X, dtype=tf.float32)
                y_test = tf.cast(y_test, dtype=tf.float32)
                dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
            else:
                # Bug fix: an unknown mode previously fell through and raised
                # UnboundLocalError on the return statement.
                raise ValueError(
                    f"unsupported mode: {mode!r} (expected 'train' or 'test')"
                )
            return dataset.batch(batch_size)

        def BCELoss(self, y_true, y_pred):
            """Element-wise binary cross-entropy between labels and predictions."""
            return tf.keras.losses.binary_crossentropy(y_true, y_pred)

        def GradLoss(self, model, grad_target, X, y_true):
            """L1 distance between the target gradient and the surrogate's input gradient.

            Computes d(mean BCE)/dX for the surrogate ``model``, clips the
            gradient to [-5, 5] for numerical stability, rescales by the batch
            size (the target gradients were saved scaled by batch size) and
            compares against ``grad_target`` under the L1 norm.
            """
            with tf.GradientTape() as tape:
                tape.watch(X)
                y_pred = model(X)
                y_pred = tf.reshape(y_pred, shape=(-1,))
                binary_crossentropy = self.BCELoss(y_true, y_pred)
                loss = tf.reduce_mean(binary_crossentropy)
            gradients = tape.gradient(loss, X)
            gradients = tf.cast(gradients, dtype=tf.float32)
            max_gradient_value = 5.0
            clipped_gradients = tf.clip_by_value(
                gradients, -max_gradient_value, max_gradient_value
            )
            # Bug fix: the clipped gradients were computed but never used —
            # the comparison now honors the clipping bound.
            gradient_loss_value = tf.norm(
                (grad_target - clipped_gradients * y_true.shape[0]), ord=1
            )
            return gradient_loss_value

        def KLLoss(self, y_pri, y_pred):
            """KL-style divergence pulling the batch-mean prediction toward the prior ``y_pri``."""
            eps = 1e-9
            # Mean predicted class distribution over the batch: [P(y=0), P(y=1)].
            class_probs = tf.stack([1 - y_pred, y_pred], axis=1)
            mean_distribution = tf.reduce_mean(class_probs, axis=0, keepdims=True)
            divergence = -y_pri * tf.math.log(mean_distribution / (y_pri + eps) + eps)
            return tf.reduce_sum(divergence)

        def custom_loss(self, model, X, y_true, y_pred, y_pri, H_y, p_hat, grad_target):
            """Weighted sum of accuracy, gradient-matching and KL losses.

            NOTE(review): ``p_hat`` is accepted for interface compatibility but
            is not used in this loss.
            """
            # Loss-term weights; adjust as needed.
            alpha_acc = 1
            alpha_grad = 0.001
            alpha_kl = 0.001
            H_y = tf.cast(H_y, dtype=tf.float32)
            y_pri = tf.cast(y_pri, dtype=tf.float32)
            # Accuracy term is normalized by the label entropy H_y.
            return (
                alpha_acc * (self.BCELoss(y_true, y_pred) / H_y)
                + alpha_grad * self.GradLoss(model, grad_target, X, y_true)
                + alpha_kl * self.KLLoss(y_pri, y_pred)
            )

        def train_surrogate_model(self, model, epochs, args_dict, X_train):
            """Train the surrogate label-inference model on aggregated hiddens.

            Args:
                model: surrogate Keras model mapping hiddens to P(y=1).
                epochs: number of training epochs.
                args_dict: must contain "H_y" (label entropy), "y_pri"
                    (label prior distribution) and "batch_size".
                X_train: aggregated hidden representations used as features.

            Side effects: prints per-epoch loss/accuracy/AUC and saves the
            model to ./model/model.pth after every epoch.
            """
            H_y = args_dict["H_y"]
            y_pri = args_dict["y_pri"]
            batch_size = args_dict["batch_size"]
            # Optimizer for the surrogate model.
            optimizer = Adam(learning_rate=0.01)

            # Data pipeline yielding (X, y, p_hat, grad_target) tuples.
            dataloader = self.get_dataloader(X_train, batch_size)

            for epoch in range(epochs):
                total_loss = 0.0
                correct_predictions = 0
                total_samples = 0

                y_true_list = []
                y_pred_list = []
                for X, y, p_hat, grad_target in dataloader:
                    with tf.GradientTape() as tape:
                        y_pred = model(X)
                        y_pred = tf.reshape(y_pred, shape=(-1,))
                        loss = self.custom_loss(
                            model, X, y, y_pred, y_pri, H_y, p_hat, grad_target
                        )

                    # Backpropagate and update the surrogate parameters.
                    gradients = tape.gradient(loss, model.trainable_variables)
                    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                    # Accumulate batch loss.
                    total_loss += loss.numpy()

                    y_pred_list.append(y_pred)
                    # Round predictions to hard 0/1 labels for accuracy.
                    predicted_labels = tf.round(y_pred)

                    y = tf.cast(y, dtype=tf.float32)
                    y_true_list.append(y)
                    correct_predictions += tf.reduce_sum(
                        tf.cast(predicted_labels == y, tf.int32)
                    )

                    total_samples += X.shape[0]
                # Average loss per sample over the epoch.
                average_loss = total_loss / total_samples

                # Epoch AUC. Bug fix: concatenate the per-batch tensors first
                # (as test_surrogate_model does) — feeding a ragged list of
                # unequal-length batch tensors to update_state fails whenever
                # the dataset size is not a multiple of batch_size.
                auc_metric = tf.keras.metrics.AUC()
                auc_metric.update_state(
                    np.concatenate(y_true_list), np.concatenate(y_pred_list)
                )
                auc_value = auc_metric.result().numpy()
                accuracy = correct_predictions / total_samples

                # Report epoch-level metrics.
                print(
                    "==================================================================================================="
                )
                print(
                    f"Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}, AUC: {auc_value:.4f}"
                )
                print(
                    "==================================================================================================="
                )

                # Checkpoint the surrogate model each epoch.
                model.save("./model/model.pth")

        def test_surrogate_model(self, model, X_test):
            """Evaluate the trained surrogate model on held-out aggregated hiddens."""
            batch_size = 1024
            # Data pipeline yielding (X, y) pairs.
            dataloader = self.get_dataloader(X_test, batch_size, mode="test")

            correct_predictions = 0
            total_samples = 0
            y_true_list = []
            y_pred_list = []
            for features, targets in dataloader:
                outputs = model(features)
                outputs = tf.reshape(outputs, shape=(-1,))
                y_pred_list.append(outputs)
                # Round predictions to hard 0/1 labels for accuracy.
                hard_labels = tf.round(outputs)
                targets = tf.cast(targets, dtype=tf.float32)
                y_true_list.append(targets)
                correct_predictions += tf.reduce_sum(
                    tf.cast(hard_labels == targets, tf.int32)
                )
                total_samples += features.shape[0]

            # AUC over the full test set.
            auc_metric = tf.keras.metrics.AUC()

            y_true_list = np.concatenate(y_true_list)
            y_pred_list = np.concatenate(y_pred_list)

            auc_metric.update_state(y_true_list, y_pred_list)
            auc_value = auc_metric.result().numpy()
            accuracy = correct_predictions / total_samples
            # Report test metrics.
            print("====================================================")
            print(f"Test Accuracy: {accuracy:.4f}, AUC: {auc_value:.4f}")

        def exploit_attack(
            self,
            x: Union[
                VDataFrame,
                FedNdarray,
                List[Union[HDataFrame, VDataFrame, FedNdarray]],
            ],
            y: Union[VDataFrame, FedNdarray, PYUObject],
            batch_size=32,
            epochs=1,
            verbose=1,
            callbacks=None,
            test_data=None,
            shuffle=False,
            sample_weight=None,
            dp_spent_step_freq=None,
            dataset_builder: Callable[[List], Tuple[int, Iterable]] = None,
            audit_log_dir: str = None,
            audit_log_params: dict = {},
            random_seed: int = None,
            surrogate_builder=None,
            args_dict={},
        ):
            """Run the ExPloit attack.

            Collects the aggregated hidden vectors from one split-learning
            forward pass, trains a surrogate model on them (using the
            gradients/predictions previously saved by ``fit_save``), then
            evaluates the surrogate on ``test_data`` if provided.

            Args mirror ``fit_save``; additionally:
                surrogate_builder: zero-arg callable returning the surrogate model.
                args_dict: surrogate hyper-parameters ("H_y", "y_pri", "batch_size").
            """
            if random_seed is None:
                random_seed = global_random(self.device_y, 100000)

            params = locals()
            logging.info(f"SL Train Params: {params}")
            # sanity check
            assert (
                isinstance(batch_size, int) and batch_size > 0
            ), "batch_size should be integer > 0"
            assert len(self._workers) == 2, "split learning only support 2 parties"
            if dp_spent_step_freq is not None:
                assert isinstance(dp_spent_step_freq, int) and dp_spent_step_freq >= 1

            # get basenet output num
            self.basenet_output_num = {
                device: reveal(worker.get_basenet_output_num())
                for device, worker in self._workers.items()
            }
            self.agglayer.set_basenet_output_num(self.basenet_output_num)
            # build dataset
            train_x, train_y = x, y
            if test_data is not None:
                logging.debug("validation_data provided")
                if len(test_data) == 2:
                    test_x, test_y = test_data
                    test_sample_weight = None
                else:
                    test_x, test_y, test_sample_weight = test_data
            else:
                test_x, test_y, test_sample_weight = None, None, None
            steps_per_epoch = self.handle_data(
                train_x,
                train_y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                shuffle=shuffle,
                epochs=epochs,
                stage="train",
                random_seed=random_seed,
                dataset_builder=dataset_builder,
            )

            # Bug fix: test_steps was previously unbound when test_data was
            # None, crashing the evaluation phase below with a NameError.
            test_steps = 0
            if test_x is not None and test_y is not None:
                test_steps = self.handle_data(
                    test_x,
                    test_y,
                    sample_weight=test_sample_weight,
                    batch_size=batch_size,
                    epochs=epochs,
                    stage="eval",
                    dataset_builder=dataset_builder,
                )

            self._workers[self.device_y].init_training(callbacks, epochs=epochs)
            [worker.on_train_begin() for worker in self._workers.values()]

            # Collect agg-layer outputs over one pass of the training data;
            # they become the surrogate model's training features.
            agg_list = []
            pbar = tqdm(total=steps_per_epoch) if verbose == 1 else None
            self._workers[self.device_y].reset_metrics()
            [worker.on_epoch_begin(0) for worker in self._workers.values()]

            for step in range(0, steps_per_epoch):
                if pbar is not None:
                    pbar.update(1)
                hiddens = {}
                self._workers[self.device_y].on_train_batch_begin(step=step)
                for device, worker in self._workers.items():
                    # 1. Local calculation of basenet
                    hidden = worker.base_forward(stage="train")
                    # 2. The results of basenet are sent to fusenet
                    hiddens[device] = hidden
                # do agglayer forward
                agg_hiddens = self.agglayer.forward(hiddens, axis=0)
                if isinstance(agg_hiddens, PYUObject):
                    agg_hiddens = [agg_hiddens]
                agg_0 = sf.reveal(agg_hiddens[0]).hidden
                agg_1 = sf.reveal(agg_hiddens[1]).hidden
                agg_list.append(tf.concat([agg_0, agg_1], axis=1))
            if pbar is not None:
                # Bug fix: the progress bar was never closed.
                pbar.close()

            agg_tensor = tf.concat(agg_list, axis=0)
            surrogate_model = surrogate_builder()
            self.train_surrogate_model(
                surrogate_model, epochs, args_dict, X_train=agg_tensor
            )

            # Evaluation phase — only when test data was supplied.
            if test_steps:
                self._workers[self.device_y].reset_metrics()
                test_agg_list = []
                for step in range(0, test_steps):
                    hiddens = {}  # driver end
                    for device, worker in self._workers.items():
                        hidden = worker.base_forward("eval")
                        hiddens[device] = hidden
                    agg_hiddens = self.agglayer.forward(hiddens, axis=0)
                    if isinstance(agg_hiddens, PYUObject):
                        agg_hiddens = [agg_hiddens]
                    agg_0 = sf.reveal(agg_hiddens[0]).hidden
                    agg_1 = sf.reveal(agg_hiddens[1]).hidden
                    test_agg_list.append(tf.concat([agg_0, agg_1], axis=1))
                test_agg_tensor = tf.concat(test_agg_list, axis=0)
                self.test_surrogate_model(surrogate_model, X_test=test_agg_tensor)

        def fit_save(
            self,
            x: Union[
                VDataFrame,
                FedNdarray,
                List[Union[HDataFrame, VDataFrame, FedNdarray]],
            ],
            y: Union[VDataFrame, FedNdarray, PYUObject],
            batch_size=32,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_data=None,
            shuffle=False,
            sample_weight=None,
            validation_freq=1,
            dp_spent_step_freq=None,
            dataset_builder: Callable[[List], Tuple[int, Iterable]] = None,
            audit_log_dir: str = None,
            audit_log_params: dict = {},
            random_seed: int = None,
            save_path: str = "",
        ):
            """Vertical split learning training interface that additionally captures
            the fuse-net gradients and predictions of the final epoch and saves them
            to ``./model/grads.npz`` for the ExPloit attack.

            Args:
                x: Input data. It could be:

                - VDataFrame: a vertically aligned dataframe.
                - FedNdArray: a vertically aligned ndarray.
                - List[Union[HDataFrame, VDataFrame, FedNdarray]]: list of dataframe or ndarray.

                y: Target data. It could be a VDataFrame or FedNdarray which has only one partition, or a PYUObject.
                batch_size: Number of samples per gradient update.
                epochs: Number of epochs to train the model
                verbose: 0, 1. Verbosity mode
                callbacks: List of `keras.callbacks.Callback` instances.
                validation_data: Data on which to validate
                shuffle: Whether shuffle dataset or not
                validation_freq: specifies how many training epochs to run before a new validation run is performed
                sample_weight: weights for the training samples
                dp_spent_step_freq: specifies how many training steps to check the budget of dp
                dataset_builder: Callable function, its input is `x` or `[x, y]` if y is set, it should return a
                    dataset.
                audit_log_dir: If audit_log_dir is set, audit model will be enabled
                audit_log_params: Kwargs for saving audit model, eg: {'save_traces'=True, 'save_format'='h5'}
                random_seed: seed for prg, will only affect dataset shuffle
                save_path: accepted for interface compatibility; currently unused.
            """
            if random_seed is None:
                random_seed = global_random(self.device_y, 100000)

            params = locals()
            logging.info(f"SL Train Params: {params}")
            # sanity check
            assert (
                isinstance(batch_size, int) and batch_size > 0
            ), "batch_size should be integer > 0"
            # Bug fix: this assertion was previously duplicated.
            assert isinstance(validation_freq, int) and validation_freq >= 1
            assert len(self._workers) == 2, "split learning only support 2 parties"
            if dp_spent_step_freq is not None:
                assert isinstance(dp_spent_step_freq, int) and dp_spent_step_freq >= 1

            # get basenet output num
            self.basenet_output_num = {
                device: reveal(worker.get_basenet_output_num())
                for device, worker in self._workers.items()
            }
            self.agglayer.set_basenet_output_num(self.basenet_output_num)
            # build dataset
            train_x, train_y = x, y
            if validation_data is not None:
                logging.debug("validation_data provided")
                if len(validation_data) == 2:
                    valid_x, valid_y = validation_data
                    valid_sample_weight = None
                else:
                    valid_x, valid_y, valid_sample_weight = validation_data
            else:
                valid_x, valid_y, valid_sample_weight = None, None, None
            steps_per_epoch = self.handle_data(
                train_x,
                train_y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                shuffle=shuffle,
                epochs=epochs,
                stage="train",
                random_seed=random_seed,
                dataset_builder=dataset_builder,
            )
            validation = False

            if valid_x is not None and valid_y is not None:
                validation = True
                valid_steps = self.handle_data(
                    valid_x,
                    valid_y,
                    sample_weight=valid_sample_weight,
                    batch_size=batch_size,
                    epochs=epochs,
                    stage="eval",
                    dataset_builder=dataset_builder,
                )

            self._workers[self.device_y].init_training(callbacks, epochs=epochs)
            [worker.on_train_begin() for worker in self._workers.values()]
            wait_steps = min(min(self.get_cpus()) * 2, 100)

            # Per-batch gradients and predictions captured on the final epoch.
            grads = []
            p_hats = []
            for epoch in range(epochs):
                res = []
                report_list = []
                report_list.append(f"epoch: {epoch+1}/{epochs} - ")
                if verbose == 1:
                    pbar = tqdm(total=steps_per_epoch)
                self._workers[self.device_y].reset_metrics()
                [worker.on_epoch_begin(epoch) for worker in self._workers.values()]

                for step in range(0, steps_per_epoch):
                    if verbose == 1:
                        pbar.update(1)
                    hiddens = {}
                    self._workers[self.device_y].on_train_batch_begin(step=step)
                    for device, worker in self._workers.items():
                        # 1. Local calculation of basenet
                        hidden = worker.base_forward(stage="train")
                        # 2. The results of basenet are sent to fusenet
                        hiddens[device] = hidden
                    # do agglayer forward
                    agg_hiddens = self.agglayer.forward(hiddens, axis=0)
                    if isinstance(agg_hiddens, PYUObject):
                        agg_hiddens = [agg_hiddens]
                    # 3. Fusenet do local calculates and return gradients
                    gradients = self._workers[self.device_y].fuse_net(*agg_hiddens)

                    if epoch == epochs - 1:
                        # Capture the scattered gradients of this batch and
                        # keep them for saving (consumed by exploit_attack).
                        scatter_gradients = self.agglayer.backward(gradients)
                        worker_list = list(self.base_model_dict.keys())
                        client_device = worker_list[0]
                        server_device = worker_list[1]
                        grad_client = scatter_gradients[client_device]
                        grad_server = scatter_gradients[server_device]
                        grad_client_np = client_device(self.convert_to_ndarray)(
                            grad_client
                        )
                        grad_server_np = server_device(self.convert_to_ndarray)(
                            grad_server
                        )
                        grad_np = jnp.concatenate(
                            (
                                sf.reveal(grad_client_np)[0],
                                sf.reveal(grad_server_np)[0],
                            ),
                            axis=1,
                        )
                        grads.append(grad_np * batch_size)
                        # Capture the fuse-net predictions (p_hat) of this batch.
                        y_pred = self._workers[self.device_y].predict(*agg_hiddens)
                        p_hat = server_device(self.convert_to_ndarray)(y_pred)

                        p_hat_list = sf.reveal(p_hat)
                        for p_hat_single in p_hat_list:
                            p_hats.append(p_hat_single)

                    # In some strategies, we need to bypass the backpropagation step.
                    skip_gradient = False
                    if self.check_skip_grad:
                        skip_gradient = reveal(
                            self._workers[self.device_y].get_skip_gradient()
                        )

                    if not skip_gradient:
                        # do agglayer backward
                        scatter_gradients = self.agglayer.backward(gradients)
                        for device, worker in self._workers.items():
                            if device in scatter_gradients.keys():
                                worker.base_backward(scatter_gradients[device])

                    r_count = self._workers[self.device_y].on_train_batch_end(step=step)
                    res.append(r_count)
                    if (
                        self.dp_strategy_dict is not None
                        and dp_spent_step_freq is not None
                    ):
                        current_step = epoch * steps_per_epoch + step
                        if current_step % dp_spent_step_freq == 0:
                            privacy_device = {}
                            for device, dp_strategy in self.dp_strategy_dict.items():
                                privacy_dict = dp_strategy.get_privacy_spent(
                                    current_step
                                )
                                privacy_device[device] = privacy_dict
                    if len(res) == wait_steps:
                        wait(res)
                        res = []

                if epoch == epochs - 1:
                    # Persist the captured predictions and (batch-size scaled)
                    # gradients for surrogate training in exploit_attack().
                    p_hats = np.concatenate(p_hats)
                    grads = np.concatenate(grads)
                    np.savez("./model/grads.npz", p_hat=p_hats, grads=grads)

                if validation and epoch % validation_freq == 0:
                    # validation
                    self._workers[self.device_y].reset_metrics()
                    res = []
                    for step in range(0, valid_steps):
                        hiddens = {}  # driver end
                        for device, worker in self._workers.items():
                            hidden = worker.base_forward("eval")
                            hiddens[device] = hidden
                        agg_hiddens = self.agglayer.forward(hiddens, axis=0)
                        if isinstance(agg_hiddens, PYUObject):
                            agg_hiddens = [agg_hiddens]
                        metrics = self._workers[self.device_y].evaluate(*agg_hiddens)
                        res.append(metrics)
                        if len(res) == wait_steps:
                            wait(res)
                            res = []
                    wait(res)
                    self._workers[self.device_y].on_validation(metrics)

                    # save checkpoint
                    if audit_log_dir is not None:
                        epoch_base_model_path = {
                            device: os.path.join(
                                audit_log_dir,
                                "base_model",
                                device.party,
                                str(epoch),
                            )
                            for device in self._workers.keys()
                        }
                        epoch_fuse_model_path = os.path.join(
                            audit_log_dir,
                            "fuse_model",
                            str(epoch),
                        )
                        self.save_model(
                            base_model_path=epoch_base_model_path,
                            fuse_model_path=epoch_fuse_model_path,
                            is_test=self.simulation,
                            **audit_log_params,
                        )
                epoch_log = self._workers[self.device_y].on_epoch_end(epoch)
                for name, metric in reveal(epoch_log).items():
                    report_list.append(f"{name}:{metric} ")
                report = " ".join(report_list)
                if verbose == 1:
                    pbar.set_postfix_str(report)
                    pbar.close()
                if reveal(self._workers[self.device_y].get_stop_training()):
                    break

            history = self._workers[self.device_y].on_train_end()
            return reveal(history)

    return ExPloitAttackSplitNNWrapper_sf


class ExPloitAttackSplitNNManager_sf(BaseManager):
    """Manager that attaches the ExPloit attack wrapper to a split-NN model class."""

    def __init__(self, *args, **kwargs):
        # Recorded verbatim and forwarded when attach() is called.
        self.args, self.kwargs = args, kwargs

    def attach(self, cls):
        """Return ``cls`` wrapped with the ExPloit attack capabilities."""
        return attach_ExPloitattack_to_splitnn_sf(cls, *self.args, **self.kwargs)